# Copyright 2020 The SQLFlow Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os


def _gen_pai_local_method(name):
    def impl(*args, **kwargs):
        import runtime.pai as pai
        method = getattr(pai, name)
        os.environ["SQLFLOW_submitter"] = "pai_local"
        return method(*args, **kwargs)

    return impl


# Public entry points: each one forwards to the same-named function in
# runtime.pai after selecting the "pai_local" submitter.
train, pred, evaluate, explain = (
    _gen_pai_local_method(method_name)
    for method_name in ('train', 'pred', 'evaluate', 'explain'))
